{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "toc": {
      "nav_menu": {},
      "number_sections": true,
      "sideBar": true,
      "skip_h1_title": false,
      "title_cell": "Table of Contents",
      "title_sidebar": "Contents",
      "toc_cell": false,
      "toc_position": {},
      "toc_section_display": true,
      "toc_window_display": false
    },
    "colab": {
      "name": "cnn-he-init.ipynb",
      "provenance": []
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "da66ef015ea54f7f90cee2e8c1bbaeef": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_fa7c10dfebd748f3ab1f60fd9edc1ead",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_538478dedfdd4bf5b380b534c3edd9cd",
              "IPY_MODEL_a18afb326c254ab4ade48601c62cfa56"
            ]
          }
        },
        "fa7c10dfebd748f3ab1f60fd9edc1ead": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "538478dedfdd4bf5b380b534c3edd9cd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_63181df12a48420eab58e5834e6914f8",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_7159fac77f3446d6af51aa72aee48bcf"
          }
        },
        "a18afb326c254ab4ade48601c62cfa56": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_d53079429a6f4e769e3d58458ebc1e17",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:20&lt;00:00, 816909.93it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_1675512b0cac4b0ea667feea501f0d51"
          }
        },
        "63181df12a48420eab58e5834e6914f8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "7159fac77f3446d6af51aa72aee48bcf": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "d53079429a6f4e769e3d58458ebc1e17": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "1675512b0cac4b0ea667feea501f0d51": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f4192eedf7f04a3b9b91f0881380253c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_f75943e348764c1c8722859cb843eefd",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_3ad4fee21a4d4a1d8b828f8678c8de5d",
              "IPY_MODEL_d7db3f78ce454eb59a6aa4bfaed54b46"
            ]
          }
        },
        "f75943e348764c1c8722859cb843eefd": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "3ad4fee21a4d4a1d8b828f8678c8de5d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_26c7d4e4ea6d4d9fb089943de9e58a68",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_2b1a70d5f4824a219d981a942e6b8ce2"
          }
        },
        "d7db3f78ce454eb59a6aa4bfaed54b46": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_9ae38a7ddaf24360a073168f4a3b78a5",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 32768/? [00:01&lt;00:00, 22655.53it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_4071236d659b47fcac517be1056d320b"
          }
        },
        "26c7d4e4ea6d4d9fb089943de9e58a68": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "2b1a70d5f4824a219d981a942e6b8ce2": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9ae38a7ddaf24360a073168f4a3b78a5": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "4071236d659b47fcac517be1056d320b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "d879d9c146914ef28f5c2f15a11f3c7f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_354c749e239f4d828e2cddaddc70700f",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_88d51334d2b24d16bb797c23f20839c9",
              "IPY_MODEL_eaec91f80e42450099dc20fafed16dbe"
            ]
          }
        },
        "354c749e239f4d828e2cddaddc70700f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "88d51334d2b24d16bb797c23f20839c9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_33f790f8b60247749b03a99c0e405eb2",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_c8ced24de51d4eaa83e3482c371011bb"
          }
        },
        "eaec91f80e42450099dc20fafed16dbe": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_cc5866ff6c084ccebe402eb15a45d82c",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:01&lt;00:00, 1456526.07it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_0f64419a21a24d588b006a0ba7167b02"
          }
        },
        "33f790f8b60247749b03a99c0e405eb2": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "c8ced24de51d4eaa83e3482c371011bb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "cc5866ff6c084ccebe402eb15a45d82c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "0f64419a21a24d588b006a0ba7167b02": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "a6481093504d405893de0b5625359cec": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_0674dacd0380411ca2368290343ddba9",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_f8a5e8c6652e4579bd021556227a50a7",
              "IPY_MODEL_6d8f91bcfc4549cba2f08e7b19fbe78c"
            ]
          }
        },
        "0674dacd0380411ca2368290343ddba9": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "f8a5e8c6652e4579bd021556227a50a7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_921b66a2dfd6413393c9928273ea901a",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_cf475058cdff49c483239c3d9a4fdc48"
          }
        },
        "6d8f91bcfc4549cba2f08e7b19fbe78c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_6dd867513895435d9be5c149f5daa2dc",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 8192/? [00:00&lt;00:00, 23717.13it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_6676bf39056841aab12e1b70f37b1054"
          }
        },
        "921b66a2dfd6413393c9928273ea901a": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "cf475058cdff49c483239c3d9a4fdc48": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "6dd867513895435d9be5c149f5daa2dc": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "6676bf39056841aab12e1b70f37b1054": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BX8Y6RyXYDmC",
        "colab_type": "text"
      },
      "source": [
        "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
        "- Author: Sebastian Raschka\n",
        "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8_QI0XyLYI9b",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 86
        },
        "outputId": "da8774e3-04ab-4a31-f4b6-ba73c52a52c3"
      },
      "source": [
        "!pip install -q IPython\n",
        "!pip install -q ipykernel\n",
        "!pip install -q torch\n",
        "# torchvision is imported below (datasets/transforms) but is not pulled in by torch\n",
        "!pip install -q torchvision\n",
        "!pip install -q watermark\n",
        "!pip install -q matplotlib\n",
        "!pip install -q tensorwatch\n",
        "# NOTE: the PyPI package 'sklearn' is a deprecated stub that refuses to install;\n",
        "# the real distribution is 'scikit-learn' (it is still imported as `sklearn`).\n",
        "!pip install -q scikit-learn\n",
        "!pip install -q pandas\n",
        "!pip install -q pydot\n",
        "!pip install -q hiddenlayer\n",
        "!pip install -q graphviz"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "\u001b[K     |████████████████████████████████| 194kB 8.3MB/s \n",
            "\u001b[K     |████████████████████████████████| 143kB 17.7MB/s \n",
            "\u001b[?25h  Building wheel for tensorwatch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Building wheel for pydotz (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9-IxdDbpYDmD",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        },
        "outputId": "b5dd8d93-8c22-4393-ae1e-a88d16ae86b5"
      },
      "source": [
        "%load_ext watermark\n",
        "%watermark -a 'Sebastian Raschka' -v -p torch,torchvision"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Sebastian Raschka \n",
            "\n",
            "CPython 3.6.9\n",
            "IPython 5.5.0\n",
            "\n",
            "torch 1.5.1+cu101\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7rAIH-eGYDmI",
        "colab_type": "text"
      },
      "source": [
        "- Runs on CPU or GPU (if available)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HmMzopaPYDmI",
        "colab_type": "text"
      },
      "source": [
        "# Model Zoo -- Convolutional Neural Network with He Initialization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I5Sn9EGUYDmJ",
        "colab_type": "text"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WC7dMQl4YDmJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import time\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torchvision import datasets\n",
        "from torchvision import transforms\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    torch.backends.cudnn.deterministic = True"
      ],
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ETu9bzaPYDmM",
        "colab_type": "text"
      },
      "source": [
        "## Settings and Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TnbydAZYYDmM",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 441,
          "referenced_widgets": [
            "da66ef015ea54f7f90cee2e8c1bbaeef",
            "fa7c10dfebd748f3ab1f60fd9edc1ead",
            "538478dedfdd4bf5b380b534c3edd9cd",
            "a18afb326c254ab4ade48601c62cfa56",
            "63181df12a48420eab58e5834e6914f8",
            "7159fac77f3446d6af51aa72aee48bcf",
            "d53079429a6f4e769e3d58458ebc1e17",
            "1675512b0cac4b0ea667feea501f0d51",
            "f4192eedf7f04a3b9b91f0881380253c",
            "f75943e348764c1c8722859cb843eefd",
            "3ad4fee21a4d4a1d8b828f8678c8de5d",
            "d7db3f78ce454eb59a6aa4bfaed54b46",
            "26c7d4e4ea6d4d9fb089943de9e58a68",
            "2b1a70d5f4824a219d981a942e6b8ce2",
            "9ae38a7ddaf24360a073168f4a3b78a5",
            "4071236d659b47fcac517be1056d320b",
            "d879d9c146914ef28f5c2f15a11f3c7f",
            "354c749e239f4d828e2cddaddc70700f",
            "88d51334d2b24d16bb797c23f20839c9",
            "eaec91f80e42450099dc20fafed16dbe",
            "33f790f8b60247749b03a99c0e405eb2",
            "c8ced24de51d4eaa83e3482c371011bb",
            "cc5866ff6c084ccebe402eb15a45d82c",
            "0f64419a21a24d588b006a0ba7167b02",
            "a6481093504d405893de0b5625359cec",
            "0674dacd0380411ca2368290343ddba9",
            "f8a5e8c6652e4579bd021556227a50a7",
            "6d8f91bcfc4549cba2f08e7b19fbe78c",
            "921b66a2dfd6413393c9928273ea901a",
            "cf475058cdff49c483239c3d9a4fdc48",
            "6dd867513895435d9be5c149f5daa2dc",
            "6676bf39056841aab12e1b70f37b1054"
          ]
        },
        "outputId": "b8a6eaca-6689-4f84-db15-6912df7ce554"
      },
      "source": [
        "##########################\n",
        "### SETTINGS\n",
        "##########################\n",
        "\n",
        "# Device: use the default CUDA device when a GPU is available. Avoid a\n",
        "# hard-coded ordinal such as 'cuda:2', which fails with \"invalid device\n",
        "# ordinal\" on machines with fewer than three GPUs (e.g., Colab).\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Hyperparameters\n",
        "random_seed = 1\n",
        "learning_rate = 0.05\n",
        "num_epochs = 10\n",
        "batch_size = 128\n",
        "\n",
        "# Architecture\n",
        "num_classes = 10\n",
        "\n",
        "\n",
        "##########################\n",
        "### MNIST DATASET\n",
        "##########################\n",
        "\n",
        "# Note transforms.ToTensor() scales input images\n",
        "# to 0-1 range\n",
        "train_dataset = datasets.MNIST(root='data', \n",
        "                               train=True, \n",
        "                               transform=transforms.ToTensor(),\n",
        "                               download=True)\n",
        "\n",
        "# download=True is a no-op when the files already exist, but makes this\n",
        "# line work on its own (e.g., if the raw test files are missing)\n",
        "test_dataset = datasets.MNIST(root='data', \n",
        "                              train=False, \n",
        "                              transform=transforms.ToTensor(),\n",
        "                              download=True)\n",
        "\n",
        "\n",
        "train_loader = DataLoader(dataset=train_dataset, \n",
        "                          batch_size=batch_size, \n",
        "                          shuffle=True)\n",
        "\n",
        "test_loader = DataLoader(dataset=test_dataset, \n",
        "                         batch_size=batch_size, \n",
        "                         shuffle=False)\n",
        "\n",
        "# Checking the dataset: inspect the shapes of a single training batch\n",
        "images, labels = next(iter(train_loader))\n",
        "print('Image batch dimensions:', images.shape)\n",
        "print('Image label dimensions:', labels.shape)"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "da66ef015ea54f7f90cee2e8c1bbaeef",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f4192eedf7f04a3b9b91f0881380253c",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "d879d9c146914ef28f5c2f15a11f3c7f",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "a6481093504d405893de0b5625359cec",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw\n",
            "Processing...\n",
            "Done!\n",
            "\n",
            "\n",
            "\n",
            "Image batch dimensions: torch.Size([128, 1, 28, 28])\n",
            "Image label dimensions: torch.Size([128])\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/pytorch/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iG2RER_sYDmP",
        "colab_type": "text"
      },
      "source": [
        "## Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3rTa9NC_YDmP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "##########################\n",
        "### MODEL\n",
        "##########################\n",
        "\n",
        "\n",
        "class ConvNet(torch.nn.Module):\n",
        "    \"\"\"Small two-block CNN for 28x28 single-channel images (e.g., MNIST).\n",
        "\n",
        "    Architecture: [conv3x3 -> ReLU -> maxpool2x2] x 2 -> linear.\n",
        "    All conv/linear weights are re-initialized with He (Kaiming) normal\n",
        "    initialization and all biases are set to zero.\n",
        "\n",
        "    forward(x) returns a tuple (logits, probas), where probas is the\n",
        "    softmax of logits over the class dimension.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, num_classes):\n",
        "        super(ConvNet, self).__init__()\n",
        "        \n",
        "        # calculate same padding:\n",
        "        # (w - k + 2*p)/s + 1 = o\n",
        "        # => p = (s(o-1) - w + k)/2\n",
        "        \n",
        "        # 28x28x1 => 28x28x4\n",
        "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
        "                                      out_channels=4,\n",
        "                                      kernel_size=(3, 3),\n",
        "                                      stride=(1, 1),\n",
        "                                      padding=1) # (1(28-1) - 28 + 3) / 2 = 1\n",
        "        # 28x28x4 => 14x14x4\n",
        "        self.pool_1 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
        "                                         stride=(2, 2),\n",
        "                                         padding=0) # (2(14-1) - 28 + 2) / 2 = 0\n",
        "        # 14x14x4 => 14x14x8\n",
        "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
        "                                      out_channels=8,\n",
        "                                      kernel_size=(3, 3),\n",
        "                                      stride=(1, 1),\n",
        "                                      padding=1) # (1(14-1) - 14 + 3) / 2 = 1\n",
        "        # 14x14x8 => 7x7x8\n",
        "        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
        "                                         stride=(2, 2),\n",
        "                                         padding=0) # (2(7-1) - 14 + 2) / 2 = 0\n",
        "        \n",
        "        self.linear_1 = torch.nn.Linear(7*7*8, num_classes)\n",
        "        \n",
        "        ###############################################\n",
        "        # Reinitialize weights using He initialization\n",
        "        ###############################################\n",
        "        # nn.init.* functions already run under torch.no_grad(), so the\n",
        "        # parameters are passed directly. nonlinearity='relu' gives gain\n",
        "        # sqrt(2), identical to the default (nonlinearity='leaky_relu', a=0).\n",
        "        for m in self.modules():\n",
        "            if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):\n",
        "                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')\n",
        "                # layers constructed with bias=False have m.bias = None\n",
        "                if m.bias is not None:\n",
        "                    nn.init.zeros_(m.bias)\n",
        "        \n",
        "    def forward(self, x):\n",
        "        out = self.conv_1(x)\n",
        "        out = F.relu(out)\n",
        "        out = self.pool_1(out)\n",
        "\n",
        "        out = self.conv_2(out)\n",
        "        out = F.relu(out)\n",
        "        out = self.pool_2(out)\n",
        "        \n",
        "        # Flatten all but the batch dimension: (N, 8, 7, 7) -> (N, 392).\n",
        "        # Unlike view(-1, 392), a wrongly-sized input raises in linear_1\n",
        "        # instead of being silently re-batched.\n",
        "        logits = self.linear_1(torch.flatten(out, start_dim=1))\n",
        "        probas = F.softmax(logits, dim=1)\n",
        "        return logits, probas\n",
        "\n",
        "    \n",
        "torch.manual_seed(random_seed)\n",
        "model = ConvNet(num_classes=num_classes)\n",
        "\n",
        "model = model.to(device)\n",
        "\n",
        "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)"
      ],
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wXAw4GyZYONT",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 257
        },
        "outputId": "07f9de01-a0f4-460d-a54c-151dfaa0a78b"
      },
      "source": [
        "import hiddenlayer as hl\n",
        "# Trace the model with a dummy input shaped like a real training batch\n",
        "hl.build_graph(model, torch.zeros([batch_size, 1, 28, 28]).to(device))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<hiddenlayer.graph.Graph at 0x7f67f3d73b70>"
            ],
            "image/svg+xml": "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n<!-- Generated by graphviz version 2.40.1 (20161225.0304)\n -->\n<!-- Title: %3 Pages: 1 -->\n<svg width=\"1121pt\" height=\"162pt\"\n viewBox=\"0.00 0.00 1121.00 162.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(72 126)\">\n<title>%3</title>\n<polygon fill=\"#ffffff\" stroke=\"transparent\" points=\"-72,36 -72,-126 1049,-126 1049,36 -72,36\"/>\n<!-- /outputs/9 -->\n<g id=\"node1\" class=\"node\">\n<title>/outputs/9</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"248,-36 178,-36 178,0 248,0 248,-36\"/>\n<text text-anchor=\"start\" x=\"186\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MaxPool2x2</text>\n</g>\n<!-- 10660650273425428469 -->\n<g id=\"node8\" class=\"node\">\n<title>10660650273425428469</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"426,-36 342,-36 342,0 426,0 426,-36\"/>\n<text text-anchor=\"start\" x=\"350\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Conv3x3 &gt; Relu</text>\n</g>\n<!-- /outputs/9&#45;&gt;10660650273425428469 -->\n<g id=\"edge6\" class=\"edge\">\n<title>/outputs/9&#45;&gt;10660650273425428469</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M248.017,-18C272.1037,-18 304.5182,-18 331.6582,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"331.7732,-21.5001 341.7732,-18 331.7732,-14.5001 331.7732,-21.5001\"/>\n<text text-anchor=\"middle\" x=\"295\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x4x14x14</text>\n</g>\n<!-- /outputs/12 -->\n<g id=\"node2\" class=\"node\">\n<title>/outputs/12</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"590,-36 520,-36 520,0 590,0 590,-36\"/>\n<text 
text-anchor=\"start\" x=\"528\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">MaxPool2x2</text>\n</g>\n<!-- /outputs/14 -->\n<g id=\"node4\" class=\"node\">\n<title>/outputs/14</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"728,-63 674,-63 674,-27 728,-27 728,-63\"/>\n<text text-anchor=\"start\" x=\"684\" y=\"-42\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Reshape</text>\n</g>\n<!-- /outputs/12&#45;&gt;/outputs/14 -->\n<g id=\"edge1\" class=\"edge\">\n<title>/outputs/12&#45;&gt;/outputs/14</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M590.3425,-24.5359C612.5906,-28.6503 641.2648,-33.9531 663.8493,-38.1297\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"663.2828,-41.5842 673.7525,-39.9611 664.5558,-34.7009 663.2828,-41.5842\"/>\n<text text-anchor=\"middle\" x=\"632\" y=\"-39\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x8x7x7</text>\n</g>\n<!-- /outputs/13 -->\n<g id=\"node3\" class=\"node\">\n<title>/outputs/13</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"582,-90 528,-90 528,-54 582,-54 582,-90\"/>\n<text text-anchor=\"start\" x=\"537\" y=\"-69\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Constant</text>\n</g>\n<!-- /outputs/13&#45;&gt;/outputs/14 -->\n<g id=\"edge2\" class=\"edge\">\n<title>/outputs/13&#45;&gt;/outputs/14</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M582.1193,-66.9848C605.0227,-62.7492 638.0591,-56.6398 663.4958,-51.9357\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"664.3337,-55.3402 673.5305,-50.08 663.0607,-48.4569 664.3337,-55.3402\"/>\n</g>\n<!-- /outputs/15 -->\n<g id=\"node5\" class=\"node\">\n<title>/outputs/15</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"855,-63 801,-63 801,-27 855,-27 855,-63\"/>\n<text text-anchor=\"start\" x=\"815\" y=\"-42\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Linear</text>\n</g>\n<!-- /outputs/14&#45;&gt;/outputs/15 -->\n<g 
id=\"edge3\" class=\"edge\">\n<title>/outputs/14&#45;&gt;/outputs/15</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M728.2447,-45C746.492,-45 770.746,-45 790.7693,-45\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"790.8004,-48.5001 800.8004,-45 790.8003,-41.5001 790.8004,-48.5001\"/>\n<text text-anchor=\"middle\" x=\"764.5\" y=\"-48\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x392</text>\n</g>\n<!-- /outputs/16 -->\n<g id=\"node6\" class=\"node\">\n<title>/outputs/16</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"977,-63 923,-63 923,-27 977,-27 977,-63\"/>\n<text text-anchor=\"start\" x=\"933\" y=\"-42\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Softmax</text>\n</g>\n<!-- /outputs/15&#45;&gt;/outputs/16 -->\n<g id=\"edge4\" class=\"edge\">\n<title>/outputs/15&#45;&gt;/outputs/16</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M855.0758,-45C872.0553,-45 894.1767,-45 912.7924,-45\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"913,-48.5001 922.9999,-45 912.9999,-41.5001 913,-48.5001\"/>\n<text text-anchor=\"middle\" x=\"889\" y=\"-48\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x10</text>\n</g>\n<!-- 9620407494180394767 -->\n<g id=\"node7\" class=\"node\">\n<title>9620407494180394767</title>\n<polygon fill=\"#e8e8e8\" stroke=\"#000000\" points=\"84,-36 0,-36 0,0 84,0 84,-36\"/>\n<text text-anchor=\"start\" x=\"8\" y=\"-15\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">Conv3x3 &gt; Relu</text>\n</g>\n<!-- 9620407494180394767&#45;&gt;/outputs/9 -->\n<g id=\"edge5\" class=\"edge\">\n<title>9620407494180394767&#45;&gt;/outputs/9</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M84.2697,-18C109.5651,-18 141.7113,-18 167.5519,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"167.8392,-21.5001 177.8391,-18 167.8391,-14.5001 167.8392,-21.5001\"/>\n<text text-anchor=\"middle\" x=\"131\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" 
fill=\"#000000\">128x4x28x28</text>\n</g>\n<!-- 10660650273425428469&#45;&gt;/outputs/12 -->\n<g id=\"edge7\" class=\"edge\">\n<title>10660650273425428469&#45;&gt;/outputs/12</title>\n<path fill=\"none\" stroke=\"#000000\" d=\"M426.2697,-18C451.5651,-18 483.7113,-18 509.5519,-18\"/>\n<polygon fill=\"#000000\" stroke=\"#000000\" points=\"509.8392,-21.5001 519.8391,-18 509.8391,-14.5001 509.8392,-21.5001\"/>\n<text text-anchor=\"middle\" x=\"473\" y=\"-21\" font-family=\"Times\" font-size=\"10.00\" fill=\"#000000\">128x8x14x14</text>\n</g>\n</g>\n</svg>\n"
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kn7msM6mYDmS",
        "colab_type": "text"
      },
      "source": [
        "## Training"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pv35yuJzYDmS",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "67a7b35f-5c6a-4fd9-c2ce-18e24b97cde5"
      },
      "source": [
        "def compute_accuracy(model, data_loader, device=None):\n",
        "    \"\"\"Return the percentage of correctly classified examples in `data_loader`.\n",
        "\n",
        "    Runs under torch.no_grad() so no autograd graph is built while\n",
        "    evaluating. `model` must return a (logits, probas) tuple; the predicted\n",
        "    label is the argmax of probas. If `device` is None, the device of the\n",
        "    model's first parameter is used.\n",
        "\n",
        "    Raises ValueError if `data_loader` yields no examples.\n",
        "    \"\"\"\n",
        "    if device is None:\n",
        "        device = next(model.parameters()).device\n",
        "    correct_pred, num_examples = 0, 0\n",
        "    with torch.no_grad():\n",
        "        for features, targets in data_loader:\n",
        "            features = features.to(device)\n",
        "            targets = targets.to(device)\n",
        "            logits, probas = model(features)\n",
        "            _, predicted_labels = torch.max(probas, 1)\n",
        "            num_examples += targets.size(0)\n",
        "            correct_pred += (predicted_labels == targets).sum().item()\n",
        "    if num_examples == 0:\n",
        "        raise ValueError('data_loader yielded no examples')\n",
        "    return correct_pred / num_examples * 100\n",
        "    \n",
        "\n",
        "start_time = time.time()\n",
        "for epoch in range(num_epochs):\n",
        "    model = model.train()\n",
        "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
        "        \n",
        "        features = features.to(device)\n",
        "        targets = targets.to(device)\n",
        "\n",
        "        ### FORWARD AND BACK PROP\n",
        "        logits, probas = model(features)\n",
        "        cost = F.cross_entropy(logits, targets)\n",
        "        optimizer.zero_grad()\n",
        "        \n",
        "        cost.backward()\n",
        "        \n",
        "        ### UPDATE MODEL PARAMETERS\n",
        "        optimizer.step()\n",
        "        \n",
        "        ### LOGGING\n",
        "        if not batch_idx % 50:\n",
        "            print ('Epoch: %03d/%03d | Batch %03d/%03d | Cost: %.4f' \n",
        "                   %(epoch+1, num_epochs, batch_idx, \n",
        "                     len(train_loader), cost.item()))\n",
        "    \n",
        "    model = model.eval()\n",
        "    print('Epoch: %03d/%03d training accuracy: %.2f%%' % (\n",
        "          epoch+1, num_epochs, \n",
        "          compute_accuracy(model, train_loader)))\n",
        "    \n",
        "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
        "    \n",
        "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 001/010 | Batch 000/469 | Cost: 2.4577\n",
            "Epoch: 001/010 | Batch 050/469 | Cost: 1.1068\n",
            "Epoch: 001/010 | Batch 100/469 | Cost: 0.6610\n",
            "Epoch: 001/010 | Batch 150/469 | Cost: 0.5354\n",
            "Epoch: 001/010 | Batch 200/469 | Cost: 0.4479\n",
            "Epoch: 001/010 | Batch 250/469 | Cost: 0.3159\n",
            "Epoch: 001/010 | Batch 300/469 | Cost: 0.4545\n",
            "Epoch: 001/010 | Batch 350/469 | Cost: 0.4277\n",
            "Epoch: 001/010 | Batch 400/469 | Cost: 0.1386\n",
            "Epoch: 001/010 | Batch 450/469 | Cost: 0.1410\n",
            "Epoch: 001/010 training accuracy: 91.96%\n",
            "Time elapsed: 0.34 min\n",
            "Epoch: 002/010 | Batch 000/469 | Cost: 0.2198\n",
            "Epoch: 002/010 | Batch 050/469 | Cost: 0.1464\n",
            "Epoch: 002/010 | Batch 100/469 | Cost: 0.2627\n",
            "Epoch: 002/010 | Batch 150/469 | Cost: 0.1919\n",
            "Epoch: 002/010 | Batch 200/469 | Cost: 0.1486\n",
            "Epoch: 002/010 | Batch 250/469 | Cost: 0.1228\n",
            "Epoch: 002/010 | Batch 300/469 | Cost: 0.1591\n",
            "Epoch: 002/010 | Batch 350/469 | Cost: 0.1410\n",
            "Epoch: 002/010 | Batch 400/469 | Cost: 0.1405\n",
            "Epoch: 002/010 | Batch 450/469 | Cost: 0.1210\n",
            "Epoch: 002/010 training accuracy: 95.21%\n",
            "Time elapsed: 0.67 min\n",
            "Epoch: 003/010 | Batch 000/469 | Cost: 0.1288\n",
            "Epoch: 003/010 | Batch 050/469 | Cost: 0.2470\n",
            "Epoch: 003/010 | Batch 100/469 | Cost: 0.1310\n",
            "Epoch: 003/010 | Batch 150/469 | Cost: 0.1890\n",
            "Epoch: 003/010 | Batch 200/469 | Cost: 0.1053\n",
            "Epoch: 003/010 | Batch 250/469 | Cost: 0.1565\n",
            "Epoch: 003/010 | Batch 300/469 | Cost: 0.1235\n",
            "Epoch: 003/010 | Batch 350/469 | Cost: 0.1388\n",
            "Epoch: 003/010 | Batch 400/469 | Cost: 0.1556\n",
            "Epoch: 003/010 | Batch 450/469 | Cost: 0.1658\n",
            "Epoch: 003/010 training accuracy: 96.44%\n",
            "Time elapsed: 1.00 min\n",
            "Epoch: 004/010 | Batch 000/469 | Cost: 0.1829\n",
            "Epoch: 004/010 | Batch 050/469 | Cost: 0.0611\n",
            "Epoch: 004/010 | Batch 100/469 | Cost: 0.1969\n",
            "Epoch: 004/010 | Batch 150/469 | Cost: 0.1072\n",
            "Epoch: 004/010 | Batch 200/469 | Cost: 0.1061\n",
            "Epoch: 004/010 | Batch 250/469 | Cost: 0.0967\n",
            "Epoch: 004/010 | Batch 300/469 | Cost: 0.0593\n",
            "Epoch: 004/010 | Batch 350/469 | Cost: 0.1031\n",
            "Epoch: 004/010 | Batch 400/469 | Cost: 0.1504\n",
            "Epoch: 004/010 | Batch 450/469 | Cost: 0.1620\n",
            "Epoch: 004/010 training accuracy: 96.62%\n",
            "Time elapsed: 1.33 min\n",
            "Epoch: 005/010 | Batch 000/469 | Cost: 0.0470\n",
            "Epoch: 005/010 | Batch 050/469 | Cost: 0.0350\n",
            "Epoch: 005/010 | Batch 100/469 | Cost: 0.1233\n",
            "Epoch: 005/010 | Batch 150/469 | Cost: 0.0433\n",
            "Epoch: 005/010 | Batch 200/469 | Cost: 0.1049\n",
            "Epoch: 005/010 | Batch 250/469 | Cost: 0.1131\n",
            "Epoch: 005/010 | Batch 300/469 | Cost: 0.2228\n",
            "Epoch: 005/010 | Batch 350/469 | Cost: 0.1272\n",
            "Epoch: 005/010 | Batch 400/469 | Cost: 0.1405\n",
            "Epoch: 005/010 | Batch 450/469 | Cost: 0.0651\n",
            "Epoch: 005/010 training accuracy: 97.22%\n",
            "Time elapsed: 1.66 min\n",
            "Epoch: 006/010 | Batch 000/469 | Cost: 0.0885\n",
            "Epoch: 006/010 | Batch 050/469 | Cost: 0.1361\n",
            "Epoch: 006/010 | Batch 100/469 | Cost: 0.1085\n",
            "Epoch: 006/010 | Batch 150/469 | Cost: 0.0794\n",
            "Epoch: 006/010 | Batch 200/469 | Cost: 0.0817\n",
            "Epoch: 006/010 | Batch 250/469 | Cost: 0.1873\n",
            "Epoch: 006/010 | Batch 300/469 | Cost: 0.1786\n",
            "Epoch: 006/010 | Batch 350/469 | Cost: 0.1109\n",
            "Epoch: 006/010 | Batch 400/469 | Cost: 0.1057\n",
            "Epoch: 006/010 | Batch 450/469 | Cost: 0.0737\n",
            "Epoch: 006/010 training accuracy: 97.23%\n",
            "Time elapsed: 1.99 min\n",
            "Epoch: 007/010 | Batch 000/469 | Cost: 0.1301\n",
            "Epoch: 007/010 | Batch 050/469 | Cost: 0.0940\n",
            "Epoch: 007/010 | Batch 100/469 | Cost: 0.0863\n",
            "Epoch: 007/010 | Batch 150/469 | Cost: 0.1714\n",
            "Epoch: 007/010 | Batch 200/469 | Cost: 0.0842\n",
            "Epoch: 007/010 | Batch 250/469 | Cost: 0.0879\n",
            "Epoch: 007/010 | Batch 300/469 | Cost: 0.0567\n",
            "Epoch: 007/010 | Batch 350/469 | Cost: 0.0805\n",
            "Epoch: 007/010 | Batch 400/469 | Cost: 0.0785\n",
            "Epoch: 007/010 | Batch 450/469 | Cost: 0.1236\n",
            "Epoch: 007/010 training accuracy: 97.47%\n",
            "Time elapsed: 2.32 min\n",
            "Epoch: 008/010 | Batch 000/469 | Cost: 0.0739\n",
            "Epoch: 008/010 | Batch 050/469 | Cost: 0.0674\n",
            "Epoch: 008/010 | Batch 100/469 | Cost: 0.1882\n",
            "Epoch: 008/010 | Batch 150/469 | Cost: 0.0757\n",
            "Epoch: 008/010 | Batch 200/469 | Cost: 0.0633\n",
            "Epoch: 008/010 | Batch 250/469 | Cost: 0.1169\n",
            "Epoch: 008/010 | Batch 300/469 | Cost: 0.0312\n",
            "Epoch: 008/010 | Batch 350/469 | Cost: 0.0825\n",
            "Epoch: 008/010 | Batch 400/469 | Cost: 0.1266\n",
            "Epoch: 008/010 | Batch 450/469 | Cost: 0.0489\n",
            "Epoch: 008/010 training accuracy: 97.52%\n",
            "Time elapsed: 2.65 min\n",
            "Epoch: 009/010 | Batch 000/469 | Cost: 0.0537\n",
            "Epoch: 009/010 | Batch 050/469 | Cost: 0.1868\n",
            "Epoch: 009/010 | Batch 100/469 | Cost: 0.0639\n",
            "Epoch: 009/010 | Batch 150/469 | Cost: 0.0393\n",
            "Epoch: 009/010 | Batch 200/469 | Cost: 0.0664\n",
            "Epoch: 009/010 | Batch 250/469 | Cost: 0.0874\n",
            "Epoch: 009/010 | Batch 300/469 | Cost: 0.1967\n",
            "Epoch: 009/010 | Batch 350/469 | Cost: 0.0718\n",
            "Epoch: 009/010 | Batch 400/469 | Cost: 0.0789\n",
            "Epoch: 009/010 | Batch 450/469 | Cost: 0.0224\n",
            "Epoch: 009/010 training accuracy: 97.89%\n",
            "Time elapsed: 2.98 min\n",
            "Epoch: 010/010 | Batch 000/469 | Cost: 0.0987\n",
            "Epoch: 010/010 | Batch 050/469 | Cost: 0.0768\n",
            "Epoch: 010/010 | Batch 100/469 | Cost: 0.1977\n",
            "Epoch: 010/010 | Batch 150/469 | Cost: 0.0396\n",
            "Epoch: 010/010 | Batch 200/469 | Cost: 0.0342\n",
            "Epoch: 010/010 | Batch 250/469 | Cost: 0.0538\n",
            "Epoch: 010/010 | Batch 300/469 | Cost: 0.1165\n",
            "Epoch: 010/010 | Batch 350/469 | Cost: 0.1022\n",
            "Epoch: 010/010 | Batch 400/469 | Cost: 0.1557\n",
            "Epoch: 010/010 | Batch 450/469 | Cost: 0.1762\n",
            "Epoch: 010/010 training accuracy: 97.79%\n",
            "Time elapsed: 3.32 min\n",
            "Total Training Time: 3.32 min\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h13EuRFJYDmV",
        "colab_type": "text"
      },
      "source": [
        "## Evaluation"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yPp2t-uGYDmV",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        },
        "outputId": "e03ef143-bc9d-4f66-916d-d9a47aad5610"
      },
      "source": [
        "print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader)))"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Test accuracy: 97.65%\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-xfQJszfYDmY",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 69
        },
        "outputId": "574434b9-05f5-47e3-d8a5-a46a22aa4106"
      },
      "source": [
        "%watermark -iv"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "numpy 1.18.5\n",
            "torch 1.5.1+cu101\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}